import taichi as ti

from ti_rt.color.space import lRGB2sRGB

from .material import Material
from .shape import Shape
from .utils import Inf, Vec3, Vec4, skyLike, Ray4
from .color import sRGB2lRGB, Spectrum2RGB, Hero
from .config import RELATIVISTIC, SPECTRUM, FRONT_LIGHT, LIMITED_LIGHT

__all__ = [
    "Scene",
]


class Scene:
    """场景类, 可放置多个 Shape"""

    def __init__(self, lst: "list[Shape]" = None, background=skyLike) -> None:
        self.objects = [] if lst is None else lst
        self.index = ti.field(int, len(self.objects))
        for i, obj in enumerate(self.objects):
            self.index[i] = obj.index
        self.background = background
        self.sensor = Spectrum2RGB()
        self.lambdaShift = ti.field(float, shape=1)

    if SPECTRUM:

        @ti.func
        def integral(self, ray: Ray4, minBounce: int, maxBounce: int) -> Vec3:
            ret = Vec3(0)
            throughout = Hero(1)
            lamda = self.sensor.rand() + self.lambdaShift[0]
            for i in range(maxBounce):
                end, color, ray, coef1, coef2 = self.hit(ray)
                scale = color.norm()
                if scale < 1e-6:
                    break
                # λ = c / ν
                throughout *= (
                    self.sensor.heroSampleInv(color / scale, lamda, 1 / coef1) * scale
                )
                if ti.static(FRONT_LIGHT):
                    throughout *= coef2
                if end:
                    if i >= minBounce:
                        ret = self.sensor.heroSample(throughout, lamda)
                    break
            return ret

    else:

        @ti.func
        def integral(self, ray: Ray4, minBounce: int, maxBounce: int) -> Vec3:
            ret = Vec3(0)
            throughout = Vec3(1)
            for i in range(maxBounce):
                end, color, ray, _, coef2 = self.hit(ray)
                scale = color.norm()
                if scale < 1e-6:
                    break
                throughout *= color
                if ti.static(FRONT_LIGHT):
                    throughout *= coef2
                if end:
                    if i >= minBounce:
                        ret = throughout
                    break
            return ret

    @ti.func
    def hit(self, ray: Ray4):
        """光线与场景求交, 返回是否结束追踪、光线颜色和出射光线"""
        closest, normal = Inf, Vec3(0)
        this = -1
        # 遍历场景内模型
        if not LIMITED_LIGHT:
            ray.direct.w = 0
        for i in range(self.index.shape[0]):
            # 取需要加载的模型编号
            index = self.index[i]
            # 当前物体的光线，此时为物体局部光线
            rayObj = Shape.shapes[index].ray2obj(ray)
            # 把所有光线（4维向量）转为3维后处理求与物体的交
            # 距离dist为光源点到该光线和物体交点的距离
            dist, norm = Shape.handleAllIntersect(index, rayObj.ray3(), closest)
            # 如果当前距离小于最小值，更新最小值
            if dist < closest:
                closest = dist
                normal = norm
                # 记录下当前击中物体
                this = index

        # end是否结束光线反射，color与物体交点颜色
        end, color, lambdaCoef, powerCoef = True, Vec3(0), 1.0, 1.0

        # 如果最近距离不是无穷远，说明已经光线击中物体
        if closest < Inf:
            # 当前光线距离最近物体的光线
            rayObj = Shape.shapes[this].ray2obj(ray)
            # 击中点 = 光线（3维化）后交点
            hitPoint = rayObj.ray3().at(closest)
            end, color, ray_ = Material.handleAllScatter(
                Shape.shapes[this].material, rayObj.ray3(), hitPoint, normal
            )
            # 下次光线追踪起点的时间
            nextT = rayObj.origin.w
            # 相对论模式下
            if ti.static(RELATIVISTIC):
                # 下次光线的时间（4维时空）减去（击中点位置减光线原点坐标）向量化
                if LIMITED_LIGHT:
                    nextT -= (hitPoint - rayObj.origin.xyz).norm()
                β = Shape.shapes[this].velocity
                absβ = β.norm()
                βCosθ = -ray.direct.xyz.normalized().dot(β)
                lambdaCoef = (1 - βCosθ) / (1 - absβ**2) ** 0.5
                powerCoef = (1 - absβ**2) ** 0.5 / (1 - βCosθ)
                if absβ > 0:
                    cosθ = βCosθ / absβ
                    cosθ_ = rayObj.direct.xyz.normalized().dot(β) / absβ
                    sinθ = (1 - cosθ**2) ** 0.5
                    sinθ_ = (1 - cosθ_**2) ** 0.5
                    if sinθ > 0:
                        powerCoef *= sinθ_ / sinθ
            # 得到折射一次的全局光线
            ray = ray_.ray4(nextT)
            if not LIMITED_LIGHT:
                ray.direct.w = 0
            ray = Shape.shapes[this].ray2ether(ray)
        else:
            color = sRGB2lRGB(self.background(ray.ray3()))
        # 返回是否结束光线追踪（当前光没有击中石头或天空或光源），光线交物体交点的颜色，弹回来的光线
        return end, color, ray, lambdaCoef, powerCoef

    @ti.func
    def oneHit(self, ray: Ray4) -> Vec3:
        _, color, ray, coef1, coef2 = self.hit(ray)
        if ti.static(SPECTRUM):
            lamda = self.sensor.rand() + self.lambdaShift[0]
            scale = color.norm()
            # λ = c / ν
            spec *= self.sensor.heroSampleInv(color / scale, lamda, coef1) * scale
            color = self.sensor.heroSample(spec, lamda)
        if ti.static(FRONT_LIGHT):
            color *= coef2
        return color

    @ti.func
    def closest(self, ray: Ray4) -> Vec3:
        """根据首个交点的距离来为画面着色"""
        closest = Inf
        color = Vec3(0)
        for i in range(self.index.shape[0]):
            index = self.index[i]
            rayObj = Shape.shapes[index].ray2obj(ray).ray3()
            dist, _ = Shape.handleAllIntersect(index, rayObj, closest)
            closest = min(closest, dist)
        interval = 16
        if closest < Inf:
            r = closest
            g = closest - interval
            b = closest - interval * 2
            color = 1 - ti.math.clamp(abs(Vec3(r, g, b)) / interval, 0, 1)
        return color

    @ti.func
    def firstNormal(self, ray: Ray4) -> Vec3:
        """根据首个交点处的法线来为画面着色, 三轴坐标分别对应RGB"""
        closest = Inf
        normal = Vec3(0)
        for i in range(self.index.shape[0]):
            index = self.index[i]
            rayObj = Shape.shapes[index].ray2obj(ray).ray3()
            dist, norm = Shape.handleAllIntersect(index, rayObj, closest)
            if dist < closest:
                closest = dist
                if norm.dot(rayObj.direct) > 0:
                    norm = -norm
                normal = Shape.shapes[index].vec2ether(Vec4(norm, 0)).xyz.normalized()
        return normal * 0.5 + 0.5
